package util

import (
	"context"
	"database/sql"
	"fmt"
	"math"
	"strings"
	"sync"
	"time"
)

// MysqlClient is a small MySQL helper on top of database/sql. It lazily opens and
// caches one *sql.DB per "Addr:DbName", sends writes to ConnOptions[0] and, when
// more options are configured, reads to the remaining ones (see GetDB).
// Build it with NewMysqlClient or NewMysqlClientEasy, which initialise the handle
// cache and the default timeouts.
type MysqlClient struct {
	// Listener, when non-nil, is started in a new goroutine after statements run,
	// receiving the SQL text, the resulting error and call-specific extra values.
	Listener func(ctx context.Context, query string, err error, args ...any)
	// dbMap caches opened handles keyed by "Addr:DbName"; guarded by mutex.
	dbMap map[string]*sql.DB
	// version caches the server version string once Version has fetched it.
	version string
	// ConnOptions[0] is the primary; any further entries serve reads.
	ConnOptions []*MysqlConnOption
	// Deadlines applied with context.WithTimeout: inserts, selects/counts (and the
	// connection ping in GetDB), updates/deletes/truncates, and transactions.
	TimeoutInsert, TimeoutSelect, TimeoutUpdate, TimeoutTrans time.Duration
	// mutex guards the lazily populated handle cache.
	mutex sync.RWMutex
}

// NewMysqlClient creates a client for the given connection options. The first
// option is the primary (used for writes, and for reads when it is the only
// one); any further options are used for reads. Timeouts default to 3s for
// insert, select and update and 5s for transactions; the exported Timeout*
// fields may be overridden afterwards. ErrMysqlWrongOption is returned when no
// option is supplied. Individual options are not validated here; GetDB checks
// them when a connection is first needed.
func NewMysqlClient(connOptions ...*MysqlConnOption) (*MysqlClient, error) {
	if len(connOptions) == 0 {
		return nil, ErrMysqlWrongOption
	}
	const defaultTimeout = 3 * time.Second
	client := new(MysqlClient)
	client.ConnOptions = connOptions
	client.dbMap = map[string]*sql.DB{}
	client.TimeoutInsert = defaultTimeout
	client.TimeoutSelect = defaultTimeout
	client.TimeoutUpdate = defaultTimeout
	client.TimeoutTrans = 5 * time.Second
	return client, nil
}

// NewMysqlClientEasy builds a client straight from DSN strings: the first DSN is
// the primary and any others are read replicas. Each DSN is turned into an
// option with NewMysqlConnOption(d, 0, 0, -1, -1). The meaning of those numeric
// arguments is defined by NewMysqlConnOption, which is not shown here.
// ErrMysqlWrongDsn is returned when no DSN is given; the first option-building
// error is returned as is.
func NewMysqlClientEasy(dsn ...string) (*MysqlClient, error) {
	if len(dsn) == 0 {
		return nil, ErrMysqlWrongDsn
	}
	opts := make([]*MysqlConnOption, len(dsn))
	for i := range dsn {
		opt, err := NewMysqlConnOption(dsn[i], 0, 0, -1, -1)
		if err != nil {
			return nil, err
		}
		opts[i] = opt
	}
	return NewMysqlClient(opts...)
}

// GetDB returns the cached *sql.DB for the primary (useSlave == false) or for a
// read replica, opening and pinging it on first use.
//
// ConnOptions[0] is the primary. When useSlave is true and more than one option
// is configured, a replica is taken from ConnOptions[1:]: directly when there is
// exactly one replica, otherwise at index RandRange(1, len(ConnOptions)).
// NOTE(review): this assumes RandRange's upper bound is exclusive — confirm, or
// the index can run past the end of the slice.
//
// The chosen option must have non-empty Addr, DbName and Dsn, else
// ErrMysqlWrongOption is returned (also when no options are configured).
// Handles are cached per "Addr:DbName" and shared by all callers. The write lock
// is held while a new handle is opened and pinged (bounded by c.TimeoutSelect),
// so concurrent first-time callers wait for one ping instead of opening
// duplicate pools; note this also briefly blocks lookups for other keys.
func (c *MysqlClient) GetDB(ctx context.Context, useSlave bool) (*sql.DB, error) {
	// Guard the index below: a client built without options would otherwise panic.
	if len(c.ConnOptions) == 0 {
		return nil, ErrMysqlWrongOption
	}
	var option *MysqlConnOption
	switch {
	case !useSlave || len(c.ConnOptions) == 1:
		option = c.ConnOptions[0]
	case len(c.ConnOptions) == 2:
		option = c.ConnOptions[1]
	default:
		option = c.ConnOptions[RandRange(1, len(c.ConnOptions))]
	}
	if option == nil || option.Addr == "" || option.DbName == "" || option.Dsn == "" {
		return nil, ErrMysqlWrongOption
	}
	key := option.Addr + ":" + option.DbName

	c.mutex.RLock()
	db := c.dbMap[key]
	c.mutex.RUnlock()
	if db != nil {
		return db, nil
	}

	c.mutex.Lock()
	defer c.mutex.Unlock()
	// Re-check under the write lock: another goroutine may have opened it meanwhile.
	if db = c.dbMap[key]; db != nil {
		return db, nil
	}
	db, err := sql.Open("mysql", option.Dsn)
	if err != nil {
		return nil, err
	}
	pingCtx, cancel := context.WithTimeout(ctx, c.TimeoutSelect)
	defer cancel()
	if err = db.PingContext(pingCtx); err != nil {
		// The pool allocated by sql.Open is neither cached nor returned; release it.
		_ = db.Close()
		return nil, err
	}
	db.SetMaxIdleConns(option.MaxIdle)
	db.SetMaxOpenConns(option.MaxOpen)
	db.SetConnMaxIdleTime(option.MaxIdleTime)
	db.SetConnMaxLifetime(option.MaxLifetime)
	// A zero-value client has no map yet; writing to a nil map would panic.
	if c.dbMap == nil {
		c.dbMap = make(map[string]*sql.DB)
	}
	// Always store, overwriting any nil placeholder; otherwise every call would
	// open (and leak) a fresh pool.
	c.dbMap[key] = db
	return db, nil
}

// CloseDB closes every cached connection handle and empties the cache. Close
// errors are ignored (best effort). A later GetDB call opens handles again on
// demand.
func (c *MysqlClient) CloseDB() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	for key, handle := range c.dbMap {
		_ = handle.Close()
		delete(c.dbMap, key)
	}
}

// Truncate runs "TRUNCATE TABLE `table`" on the primary connection, bounded by
// c.TimeoutUpdate. The table name is placed inside backquotes without further
// validation, so it must come from trusted code. Once the statement has been
// attempted, the Listener (if set) is notified asynchronously with the query,
// its error and the table name.
func (c *MysqlClient) Truncate(ctx context.Context, table string) error {
	db, err := c.GetDB(ctx, false)
	if err != nil {
		return err
	}
	query := "TRUNCATE TABLE `" + table + "`"
	defer func() {
		if listener := c.Listener; listener != nil {
			go listener(ctx, query, err, table)
		}
	}()
	execCtx, cancel := context.WithTimeout(ctx, c.TimeoutUpdate)
	defer cancel()
	if _, err = db.ExecContext(execCtx, query); err != nil {
		return err
	}
	return nil
}

// Version returns the MySQL server version as reported by
// "SELECT VERSION() AS `ver`", run on a read connection (via SelectWalk). The
// first successful non-empty result is cached on the client and returned by
// later calls without querying.
//
// The cache is read and written under c.mutex, so concurrent callers do not
// race. The lock is not held during the query, because SelectWalk -> GetDB
// takes the same mutex.
func (c *MysqlClient) Version(ctx context.Context) (string, error) {
	c.mutex.RLock()
	ver := c.version
	c.mutex.RUnlock()
	if ver != "" {
		return ver, nil
	}
	if err := c.SelectWalk(ctx, func(_ context.Context, row *MysqlRow) error {
		if ver == "" {
			ver = row.ToStr("ver")
		}
		return nil
	}, "SELECT VERSION() AS `ver`"); err != nil {
		return "", err
	}
	c.mutex.Lock()
	defer c.mutex.Unlock()
	// If a concurrent caller already cached a value, keep that one.
	if c.version == "" {
		c.version = ver
	}
	return c.version, nil
}

// doInsert builds and runs a single-row INSERT, INSERT IGNORE (ignore takes
// precedence) or REPLACE for row on the primary connection, bounded by
// c.TimeoutInsert, and returns the driver's LastInsertId. All values are bound
// as placeholders. Column names containing a backquote are rejected with
// ErrMysqlColWithBQuote; the table name is not validated. Column order follows
// map iteration, so the SQL text may differ between calls. Once the query is
// built, the Listener (if set) is notified asynchronously with the query, error,
// id and params, including when GetDB fails.
func (c *MysqlClient) doInsert(ctx context.Context, table string, row *MysqlRow, ignore, replace bool) (ret int64, err error) {
	if row.IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	columns := make([]string, 0)
	marks := make([]string, 0)
	params := make([]any, 0)
	for name, value := range *row {
		if strings.Contains(name, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		columns = append(columns, name)
		marks = append(marks, "?")
		params = append(params, value)
	}
	var verb string
	switch {
	case ignore:
		verb = "INSERT IGNORE "
	case replace:
		verb = "REPLACE "
	default:
		verb = "INSERT "
	}
	query := verb + "INTO `" + table + "` (`" + strings.Join(columns, "`, `") + "`) VALUES (" + strings.Join(marks, ", ") + ")"
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, ret, params)
		}
	}()
	db, dbErr := c.GetDB(ctx, false)
	if dbErr != nil {
		return 0, dbErr
	}
	execCtx, cancel := context.WithTimeout(ctx, c.TimeoutInsert)
	defer cancel()
	res, execErr := db.ExecContext(execCtx, query, params...)
	if execErr != nil {
		return 0, execErr
	}
	return res.LastInsertId()
}

// Insert adds row to table with a plain INSERT and returns the LastInsertId
// reported by the driver.
func (c *MysqlClient) Insert(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	const ignore, replace = false, false
	ret, err = c.doInsert(ctx, table, row, ignore, replace)
	return ret, err
}

// InsertIgnore adds row to table with INSERT IGNORE and returns the
// LastInsertId reported by the driver.
func (c *MysqlClient) InsertIgnore(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	const ignore, replace = true, false
	ret, err = c.doInsert(ctx, table, row, ignore, replace)
	return ret, err
}

// InsertReplace writes row to table with REPLACE and returns the LastInsertId
// reported by the driver.
func (c *MysqlClient) InsertReplace(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	const ignore, replace = false, true
	ret, err = c.doInsert(ctx, table, row, ignore, replace)
	return ret, err
}

// InsertDuplicate performs an upsert:
// "INSERT ... ON DUPLICATE KEY UPDATE col = ?, ...", where every column of row
// is both inserted and, on key conflict, overwritten with the same value. It
// runs on the primary connection, bounded by c.TimeoutInsert, and returns the
// driver's LastInsertId. Column names containing a backquote are rejected with
// ErrMysqlColWithBQuote. Once the query is built, the Listener (if set) is
// notified asynchronously with the query, error, id and params.
func (c *MysqlClient) InsertDuplicate(ctx context.Context, table string, row *MysqlRow) (ret int64, err error) {
	if row.IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	var (
		columns = make([]string, 0)
		marks   = make([]string, 0)
		updates = make([]string, 0)
		values  = make([]any, 0)
	)
	for name, value := range *row {
		if strings.Contains(name, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		columns = append(columns, name)
		marks = append(marks, "?")
		updates = append(updates, "`"+name+"` = ?")
		values = append(values, value)
	}
	// Every value is bound twice: once for VALUES (...) and once for the UPDATE list.
	params := append(values, values...)
	query := "INSERT INTO `" + table + "` (`" + strings.Join(columns, "`, `") + "`) VALUES (" +
		strings.Join(marks, ", ") + ") ON DUPLICATE KEY UPDATE " + strings.Join(updates, ", ")
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, ret, params)
		}
	}()
	db, dbErr := c.GetDB(ctx, false)
	if dbErr != nil {
		return 0, dbErr
	}
	execCtx, cancel := context.WithTimeout(ctx, c.TimeoutInsert)
	defer cancel()
	res, execErr := db.ExecContext(execCtx, query, params...)
	if execErr != nil {
		return 0, execErr
	}
	return res.LastInsertId()
}

// InsertBatch inserts rows into table, one prepared-statement execution per row,
// using INSERT IGNORE. Rows the server ignores (for example duplicate keys) are
// skipped silently. It returns the sum of RowsAffected across all rows.
//
// The column list comes from rows[0]. Every row is read through row.Get for
// exactly those columns, so columns that only appear in later rows are not
// inserted. NOTE(review): columns missing from a later row presumably bind
// whatever MysqlRow.Get returns for an absent key — confirm.
//
// Rows are not wrapped in a transaction. On failure, rows already executed stay
// inserted, and the count so far is returned with the error. The whole batch
// shares a single 3*c.TimeoutInsert deadline. The Listener, if set, is notified
// asynchronously once per executed row, including a failed one.
func (c *MysqlClient) InsertBatch(ctx context.Context, table string, rows []*MysqlRow) (ret int64, err error) {
	if len(rows) == 0 || rows[0].IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	cols := make([]string, 0)
	phds := make([]string, 0)
	for col := range *(rows[0]) {
		if strings.Contains(col, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		cols = append(cols, col)
		phds = append(phds, "?")
	}
	query := "INSERT IGNORE INTO `" + table + "` (`" + strings.Join(cols, "`, `") + "`) VALUES (" + strings.Join(phds, ", ") + ")"
	var db *sql.DB
	if db, err = c.GetDB(ctx, false); err != nil {
		return 0, err
	}
	ctx1, cancel := context.WithTimeout(ctx, 3*c.TimeoutInsert)
	defer cancel()
	var stmt *sql.Stmt
	if stmt, err = db.PrepareContext(ctx1, query); err != nil {
		return 0, err
	}
	defer func() {
		_ = stmt.Close()
	}()
	for _, row := range rows {
		params := make([]any, 0, len(cols))
		for _, col := range cols {
			params = append(params, row.Get(col))
		}
		var res sql.Result
		if res, err = stmt.ExecContext(ctx1, params...); err != nil {
			if c.Listener != nil {
				go c.Listener(ctx, query, err, int64(0), params)
			}
			return ret, err
		}
		var affected int64
		if affected, err = res.RowsAffected(); err != nil {
			return ret, err
		}
		ret += affected
		if c.Listener != nil {
			// Keep the LastInsertId error local: it is listener information only
			// and must not become the batch's return value.
			newID, idErr := res.LastInsertId()
			go c.Listener(ctx, query, idErr, newID, params)
		}
	}
	return ret, nil
}

// SelectWalk runs a read-only query on a read connection (see GetDB) and calls
// fn once for each result row, in result order. The query must begin with
// "select " or "show " (case-insensitive) and be at least 8 bytes long;
// otherwise ErrMysqlWrongSql is returned without touching the database.
//
// Each row reaches fn as a fresh *MysqlRow whose values are the raw column bytes
// converted to string, so SQL NULL and the empty string both arrive as "".
// The walk stops at the first error from fn, from Scan or from the row cursor
// itself, and returns that error. The query and the whole walk, including time
// spent inside fn, share one c.TimeoutSelect deadline; fn receives the caller's
// ctx. Once a connection is obtained, the Listener (if set) is notified
// asynchronously when the walk ends, including when the query itself fails.
func (c *MysqlClient) SelectWalk(
	ctx context.Context,
	fn func(ctx context.Context, row *MysqlRow) error,
	query string,
	params ...any,
) (err error) {
	// Check the length before slicing: query[0:7] panics on shorter strings.
	if len(query) < 8 ||
		(strings.ToLower(query[:7]) != "select " && strings.ToLower(query[:5]) != "show ") {
		return ErrMysqlWrongSql
	}
	var db *sql.DB
	if db, err = c.GetDB(ctx, true); err != nil {
		return err
	}
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, params)
		}
	}()
	ctx1, cancel := context.WithTimeout(ctx, c.TimeoutSelect)
	defer cancel()
	var rows *sql.Rows
	if rows, err = db.QueryContext(ctx1, query, params...); err != nil {
		return err
	}
	defer func() {
		_ = rows.Close()
	}()
	var cols []string
	if cols, err = rows.Columns(); err != nil {
		return err
	}
	// Scan targets are reused across rows. RawBytes are only valid until the next
	// Next/Scan, and each value is copied out with string(bs) before that.
	rawBuffers := make([]sql.RawBytes, len(cols))
	scanArgs := make([]any, len(cols))
	for i := range rawBuffers {
		scanArgs[i] = &rawBuffers[i]
	}
	for rows.Next() {
		if err = rows.Scan(scanArgs...); err != nil {
			return err
		}
		row := NewMysqlRow()
		for i, bs := range rawBuffers {
			row.Set(cols[i], string(bs))
		}
		if err = fn(ctx, row); err != nil {
			return err
		}
	}
	// Next also returns false on failure (e.g. the deadline expiring mid-stream);
	// without this check a truncated result would be reported as success.
	err = rows.Err()
	return err
}

// ShowColumns runs "SHOW COLUMNS FROM `table`" and returns a map from each
// column's Field value to its Type value. The table name is interpolated inside
// backquotes without validation. On error the map is nil.
func (c *MysqlClient) ShowColumns(ctx context.Context, table string) (map[string]string, error) {
	columns := map[string]string{}
	collect := func(_ context.Context, row *MysqlRow) error {
		columns[row.ToStr("Field")] = row.ToStr("Type")
		return nil
	}
	query := "SHOW COLUMNS FROM `" + table + "`"
	if err := c.SelectWalk(ctx, collect, query); err != nil {
		return nil, err
	}
	return columns, nil
}

// Select runs query through SelectWalk and collects every row into a slice.
// If the walk fails, the rows gathered before the failure are returned together
// with the error.
func (c *MysqlClient) Select(ctx context.Context, query string, params ...any) (rows []*MysqlRow, err error) {
	collect := func(_ context.Context, r *MysqlRow) error {
		rows = append(rows, r)
		return nil
	}
	err = c.SelectWalk(ctx, collect, query, params...)
	return rows, err
}

// Count runs a "SELECT COUNT(...)" query on a read connection, bounded by
// c.TimeoutSelect, and scans its single result into cnt. Queries shorter than
// 14 bytes, or not starting with "select count(" (case-insensitive), are
// rejected with ErrMysqlWrongSql. Once a connection is obtained, the Listener
// (if set) is notified asynchronously with the query, error, count and params.
func (c *MysqlClient) Count(ctx context.Context, query string, params ...any) (cnt int64, err error) {
	// Check the length before slicing: query[0:13] panics on shorter strings.
	if len(query) < 14 || strings.ToLower(query[:13]) != "select count(" {
		return 0, ErrMysqlWrongSql
	}
	var db *sql.DB
	if db, err = c.GetDB(ctx, true); err != nil {
		return 0, err
	}
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, cnt, params)
		}
	}()
	ctx1, cancel := context.WithTimeout(ctx, c.TimeoutSelect)
	defer cancel()
	err = db.QueryRowContext(ctx1, query, params...).Scan(&cnt)
	return cnt, err
}

// SelectPage runs a paginated SELECT over table and calls fn for each row of the
// requested page.
//
// cols, where and order are raw SQL fragments: the select list, the WHERE
// condition and the ORDER BY expression, given without their keywords. They are
// interpolated verbatim, so they must never contain untrusted input; only params
// are bound as placeholders.
//
// Defaults for empty or out-of-range arguments:
//   - empty where matches all rows ("1 = 1")
//   - empty cols selects "*"
//   - empty order adds no ORDER BY
//   - size < 1 becomes 10
//   - page < 1 becomes 1
//
// Matching rows are counted first. When there are none, all results are zero
// and no page query is issued. A page past the last one yields
// ErrMysqlPageTooLarge, with totalRows, totalPages and currentPage still set.
func (c *MysqlClient) SelectPage(
	ctx context.Context,
	fn func(ctx context.Context, row *MysqlRow) error,
	page, size int64,
	table, where, order, cols string,
	params ...any,
) (totalRows, totalPages, currentPage int64, err error) {
	if where == "" {
		where = "1 = 1"
	}
	totalRows, err = c.Count(ctx, fmt.Sprintf("SELECT COUNT(*) FROM `%s` WHERE %s", table, where), params...)
	if err != nil || totalRows == 0 {
		return
	}
	if size < 1 {
		size = 10
	}
	totalPages = int64(math.Ceil(float64(totalRows) / float64(size)))
	if page < 1 {
		page = 1
	}
	currentPage = page
	if page > totalPages {
		err = ErrMysqlPageTooLarge
		return
	}
	if cols == "" {
		cols = "*"
	}
	if order != "" {
		order = "ORDER BY " + order
	}
	// Clip capacity before appending so the LIMIT arguments never land in spare
	// capacity of the caller's slice (e.g. a shared base args slice passed as
	// args...), which would corrupt it or race with concurrent calls.
	params = append(params[:len(params):len(params)], (page-1)*size, size)
	query := fmt.Sprintf("SELECT %s FROM `%s` WHERE %s %s LIMIT ?, ?", cols, table, where, order)
	err = c.SelectWalk(ctx, fn, query, params...)
	return
}

// SelectByIds walks the rows of table whose `id` column matches any of ids,
// calling fn for each (see SelectWalk). A single id uses "`id` = ?"; several use
// an IN list with one placeholder per id. cols is a raw select list ("*" when
// empty) and, like table, is interpolated verbatim. ErrMysqlEmptyData is
// returned when no id is given.
func (c *MysqlClient) SelectByIds(
	ctx context.Context,
	fn func(ctx context.Context, row *MysqlRow) error,
	table, cols string,
	ids ...any,
) error {
	if len(ids) == 0 {
		return ErrMysqlEmptyData
	}
	if cols == "" {
		cols = "*"
	}
	cond := "`id` = ?"
	if len(ids) > 1 {
		marks := strings.Repeat(" ?,", len(ids))
		cond = "`id` IN (" + marks[:len(marks)-1] + ")"
	}
	return c.SelectWalk(ctx, fn, "SELECT "+cols+" FROM `"+table+"` WHERE "+cond, ids...)
}

// Update executes an UPDATE or DELETE statement on the primary connection,
// bounded by c.TimeoutUpdate, and returns the number of affected rows.
// Statements shorter than 8 bytes, or whose first seven bytes are not
// "update " / "delete " (case-insensitive), are rejected with ErrMysqlWrongSql
// before touching the database. Once a connection is obtained, the Listener
// (if set) is notified asynchronously with the query, error, affected count and
// params.
func (c *MysqlClient) Update(ctx context.Context, query string, params ...any) (ret int64, err error) {
	// Check the length before slicing: query[0:7] panics on shorter strings.
	if len(query) < 8 {
		return 0, ErrMysqlWrongSql
	}
	if verb := strings.ToLower(query[:7]); verb != "update " && verb != "delete " {
		return 0, ErrMysqlWrongSql
	}
	var db *sql.DB
	if db, err = c.GetDB(ctx, false); err != nil {
		return 0, err
	}
	defer func() {
		if c.Listener != nil {
			go c.Listener(ctx, query, err, ret, params)
		}
	}()
	ctx1, cancel := context.WithTimeout(ctx, c.TimeoutUpdate)
	defer cancel()
	var res sql.Result
	if res, err = db.ExecContext(ctx1, query, params...); err != nil {
		return 0, err
	}
	return res.RowsAffected()
}

// UpdateById sets the columns of row on the record of table whose `id` equals
// id, via Update, and returns the number of affected rows. An "id" entry in row
// is skipped, since the key is only used in the WHERE clause, and row itself is
// not modified. Column names containing a backquote are rejected with
// ErrMysqlColWithBQuote. A row with nothing to set yields ErrMysqlEmptyData.
// The table name is interpolated verbatim.
func (c *MysqlClient) UpdateById(ctx context.Context, table string, row *MysqlRow, id any) (int64, error) {
	if row.IsEmpty() {
		return 0, ErrMysqlEmptyData
	}
	sets := make([]string, 0)
	params := make([]any, 0)
	for col, param := range *row {
		// Skip the key here instead of deleting it from row, so the caller's data
		// is left untouched.
		if col == "id" {
			continue
		}
		if strings.Contains(col, "`") {
			return 0, ErrMysqlColWithBQuote
		}
		sets = append(sets, "`"+col+"` = ?")
		params = append(params, param)
	}
	if len(sets) == 0 {
		return 0, ErrMysqlEmptyData
	}
	params = append(params, id)
	return c.Update(ctx, fmt.Sprintf("UPDATE `%s` SET %s WHERE `id` = ?", table, strings.Join(sets, ", ")), params...)
}

// Delete executes a DELETE statement through Update and returns the number of
// affected rows. Statements shorter than 8 bytes, or not starting with
// "delete " (case-insensitive), are rejected with ErrMysqlWrongSql.
func (c *MysqlClient) Delete(ctx context.Context, query string, params ...any) (int64, error) {
	// Check the length before slicing: query[0:7] panics on shorter strings.
	if len(query) < 8 || strings.ToLower(query[:7]) != "delete " {
		return 0, ErrMysqlWrongSql
	}
	return c.Update(ctx, query, params...)
}

// DeleteByIds deletes the rows of table whose `id` matches any of ids and
// returns the number of affected rows. A single id uses "`id` = ?"; several use
// an IN list with one placeholder per id. ErrMysqlEmptyData is returned when no
// id is given. The table name is interpolated verbatim.
func (c *MysqlClient) DeleteByIds(ctx context.Context, table string, ids ...any) (int64, error) {
	if len(ids) == 0 {
		return 0, ErrMysqlEmptyData
	}
	cond := "`id` = ?"
	if len(ids) > 1 {
		marks := strings.Repeat(" ?,", len(ids))
		cond = "`id` IN (" + marks[:len(marks)-1] + ")"
	}
	return c.Delete(ctx, "DELETE FROM `"+table+"` WHERE "+cond, ids...)
}

// Transaction runs fn inside a transaction on the primary connection and
// commits it when fn returns nil; a Commit failure is returned as is.
//
// If fn returns an error, the transaction is rolled back and that error is
// returned. If fn panics, the transaction is rolled back and the panic is
// converted into a returned error rather than re-raised.
//
// The transaction is begun with a context bounded by c.TimeoutTrans, while fn
// receives the caller's ctx. database/sql rolls back a transaction whose
// BeginTx context ends, so work through tx is still bounded by TimeoutTrans.
func (c *MysqlClient) Transaction(
	ctx context.Context,
	fn func(ctx context.Context, tx *sql.Tx) error,
	opts *sql.TxOptions,
) (err error) {
	var db *sql.DB
	if db, err = c.GetDB(ctx, false); err != nil {
		return err
	}
	txCtx, cancel := context.WithTimeout(ctx, c.TimeoutTrans)
	defer cancel()
	var tx *sql.Tx
	if tx, err = db.BeginTx(txCtx, opts); err != nil {
		return err
	}
	defer func() {
		if r := recover(); r != nil {
			_ = tx.Rollback()
			// Report the failure. Without this assignment the recovered panic
			// surfaced as a nil error, i.e. an apparently successful transaction.
			err = fmt.Errorf("mysql transaction panicked: %v", r)
		}
	}()
	if err = fn(ctx, tx); err != nil {
		// The rollback error is dropped on purpose: fn's error is the one the caller needs.
		_ = tx.Rollback()
		return err
	}
	// No rollback after a failed Commit: database/sql has already finished the tx.
	return tx.Commit()
}
